{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "55h6JBJeJGqg"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YrOuxMeKJZxC"
      },
      "source": [
        "# Use Apache Beam and BigQuery to enrich data\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/bigquery_enrichment_transform.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pf2bL-PmJScZ"
      },
      "source": [
        "This notebook shows how to enrich data by using the Apache Beam [enrichment transform](https://beam.apache.org/documentation/transforms/python/elementwise/enrichment/) with [BigQuery](https://cloud.google.com/bigquery/docs/overview). The enrichment transform is an Apache Beam turnkey transform that lets you enrich data by using a key-value lookup. This transform has the following features:\n",
        "\n",
        "- The transform has a built-in Apache Beam handler that interacts with BigQuery data during enrichment.\n",
        "- The enrichment transform uses client-side throttling to rate limit the requests. The default retry strategy uses exponential backoff. You can configure rate limiting to suit your use case.\n",
        "\n",
        "This notebook demonstrates the following telecommunications company use case:\n",
        "\n",
        "A telecom company wants to predict which customers are likely to cancel their subscriptions so that the company can proactively offer these customers incentives to stay. The example uses customer demographic data and usage data stored in BigQuery to enrich a stream of customer IDs. The enriched data is then used to predict the likelihood of customer churn.\n",
        "\n",
        "## Before you begin\n",
        "Set up your environment and download dependencies.\n",
        "\n",
        "### Install Apache Beam\n",
        "To use the enrichment transform with the built-in BigQuery handler, install the Apache Beam SDK version 2.57.0 or later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oVbWf73FJSzf"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "!pip install torch\n",
        "!pip install apache_beam[interactive,gcp]==2.57.0 --quiet"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "siSUsfR5tKX9"
      },
      "source": [
        "Import the following modules:\n",
        "- Pub/Sub for streaming data\n",
        "- BigQuery for enrichment\n",
        "- Apache Beam for running the streaming pipeline\n",
        "- PyTorch to predict customer churn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p6bruDqFJkXE"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "import datetime\n",
        "import json\n",
        "import math\n",
        "\n",
        "from typing import Any\n",
        "from typing import Dict\n",
        "\n",
        "import torch\n",
        "from google.cloud import pubsub_v1\n",
        "from google.cloud import bigquery\n",
        "from google.api_core.exceptions import Conflict\n",
        "\n",
        "import apache_beam as beam\n",
        "import apache_beam.runners.interactive.interactive_beam as ib\n",
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor\n",
        "from apache_beam.options import pipeline_options\n",
        "from apache_beam.runners.interactive.interactive_runner import InteractiveRunner\n",
        "from apache_beam.transforms.enrichment import Enrichment\n",
        "from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.preprocessing import LabelEncoder"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t0QfhuUlJozO"
      },
      "source": [
        "### Authenticate with Google Cloud\n",
        "This notebook reads data from Pub/Sub and BigQuery. To use your Google Cloud account, authenticate this notebook.\n",
        "To prepare for this step, replace `<PROJECT_ID>` with your Google Cloud project ID."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RwoBZjD1JwnD"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "PROJECT_ID = \"<PROJECT_ID>\" # @param {type:'string'}\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rVAyQxoeKflB"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user(project_id=PROJECT_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1vDwknoHKoa-"
      },
      "source": [
        "### Set up the BigQuery tables\n",
        "\n",
        "Create sample BigQuery tables for this notebook.\n",
        "\n",
        "- Replace `<DATASET_ID>` with the name of your BigQuery dataset. Only letters (uppercase or lowercase), numbers, and underscores are allowed.\n",
        "- If the dataset does not exist, a new dataset with this ID is created."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UxeGFqSJu-G6"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "DATASET_ID = \"<DATASET_ID>\" # @param {type:'string'}\n",
        "\n",
        "CUSTOMERS_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.customers'\n",
        "USAGE_TABLE_ID = f'{PROJECT_ID}.{DATASET_ID}.usage'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gw4RfZavyfpo"
      },
      "source": [
        "Create customer and usage tables, and insert fake data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-QRZC4v0KipK"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "client = bigquery.Client(project=PROJECT_ID)\n",
        "\n",
        "# Create dataset if it does not exist.\n",
        "client.create_dataset(bigquery.Dataset(f\"{PROJECT_ID}.{DATASET_ID}\"), exists_ok=True)\n",
        "print(f\"Created dataset {DATASET_ID}\")\n",
        "\n",
        "# Prepare the fake customer data.\n",
        "customer_data = {\n",
        "    'customer_id': [1, 2, 3, 4, 5],\n",
        "    'age': [35, 28, 45, 62, 22],\n",
        "    'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver'],\n",
        "    'contract_length': [12, 24, 6, 36, 12]\n",
        "}\n",
        "\n",
        "customers_df = pd.DataFrame(customer_data)\n",
        "\n",
        "# Insert customer data.\n",
        "job_config = bigquery.LoadJobConfig(\n",
        "    schema=[\n",
        "        bigquery.SchemaField(\"customer_id\", \"INTEGER\"),\n",
        "        bigquery.SchemaField(\"age\", \"INTEGER\"),\n",
        "        bigquery.SchemaField(\"plan\", \"STRING\"),\n",
        "        bigquery.SchemaField(\"contract_length\", \"INTEGER\"),\n",
        "    ],\n",
        "    write_disposition=\"WRITE_TRUNCATE\",\n",
        ")\n",
        "\n",
        "job = client.load_table_from_dataframe(\n",
        "    customers_df, CUSTOMERS_TABLE_ID, job_config=job_config\n",
        ")\n",
        "job.result()  # Wait for the job to complete.\n",
        "print(f\"Customers table created and populated: {CUSTOMERS_TABLE_ID}\")\n",
        "\n",
        "# Prepare the fake usage data.\n",
        "usage_data = {\n",
        "    'customer_id': [1, 1, 2, 2, 3, 3, 4, 4, 5, 5],\n",
        "    'date': pd.to_datetime(['2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01', '2024-09-01', '2024-10-01']),\n",
        "    'calls_made': [50, 65, 20, 18, 100, 110, 30, 28, 60, 70],\n",
        "    'data_usage_gb': [10, 12, 5, 4, 20, 22, 8, 7, 15, 18]\n",
        "}\n",
        "usage_df = pd.DataFrame(usage_data)\n",
        "\n",
        "# Insert usage data.\n",
        "job_config = bigquery.LoadJobConfig(\n",
        "    schema=[\n",
        "        bigquery.SchemaField(\"customer_id\", \"INTEGER\"),\n",
        "        bigquery.SchemaField(\"date\", \"DATE\"),\n",
        "        bigquery.SchemaField(\"calls_made\", \"INTEGER\"),\n",
        "        bigquery.SchemaField(\"data_usage_gb\", \"FLOAT\"),\n",
        "    ],\n",
        "    write_disposition=\"WRITE_TRUNCATE\",\n",
        ")\n",
        "job = client.load_table_from_dataframe(\n",
        "    usage_df, USAGE_TABLE_ID, job_config=job_config\n",
        ")\n",
        "job.result()  # Wait for the job to complete.\n",
        "\n",
        "print(f\"Usage table created and populated: {USAGE_TABLE_ID}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PZCjCzxaLOJt"
      },
      "source": [
        "### Train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R4dIHclDLfIj"
      },
      "source": [
        "Create sample data and train a simple model for churn prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YoMjdqJ1KxOM"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "# Create fake training data\n",
        "data = {\n",
        "    'customer_id': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],\n",
        "    'age': [35, 28, 45, 62, 22, 38, 55, 25, 40, 30],\n",
        "    'plan': ['Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Gold', 'Silver', 'Bronze', 'Silver'],\n",
        "    'contract_length': [12, 24, 6, 36, 12, 18, 30, 12, 24, 18],\n",
        "    'avg_monthly_calls': [57.5, 19, 100, 30, 60, 45, 25, 70, 50, 35],\n",
        "    'avg_monthly_data_usage_gb': [11, 4.5, 20, 8, 15, 10, 7, 18, 12, 8],\n",
        "    'churned': [0, 0, 1, 0, 1, 0, 0, 1, 0, 1]  # Target variable\n",
        "}\n",
        "plan_encoder = LabelEncoder()\n",
        "plan_encoder.fit(data['plan'])\n",
        "df = pd.DataFrame(data)\n",
        "df['plan'] = plan_encoder.transform(data['plan'])\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EgIFJx76MF3v"
      },
      "source": [
        "Preprocess the data:\n",
        "\n",
        "1.   Convert the lists to tensors.\n",
        "2.   Separate the features from the expected prediction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P-8lKzdzLnGo"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "features = ['age', 'plan', 'contract_length', 'avg_monthly_calls', 'avg_monthly_data_usage_gb']\n",
        "target = 'churned'\n",
        "\n",
        "X = torch.tensor(df[features].values, dtype=torch.float)\n",
        "Y = torch.tensor(df[target], dtype=torch.float)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4mcNOez1MQZP"
      },
      "source": [
        "Define a model that has five input features and predicts a single value."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YvdPNlzoMTtl"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "def build_model(n_inputs, n_outputs):\n",
        "  \"\"\"build_model builds and returns a model that takes\n",
        "  `n_inputs` features and predicts `n_outputs` value\"\"\"\n",
        "  return torch.nn.Sequential(\n",
        "      torch.nn.Linear(n_inputs, 8),\n",
        "      torch.nn.ReLU(),\n",
        "      torch.nn.Linear(8, 16),\n",
        "      torch.nn.ReLU(),\n",
        "      torch.nn.Linear(16, n_outputs),\n",
        "      torch.nn.Sigmoid())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GaLBmcvrMOWy"
      },
      "source": [
        "Train the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0XqctMiPMaim"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "model = build_model(n_inputs=5, n_outputs=1)\n",
        "\n",
        "loss_fn = torch.nn.BCELoss()\n",
        "optimizer = torch.optim.Adam(model.parameters())\n",
        "\n",
        "for epoch in range(1000):\n",
        "  print(f'Epoch {epoch}: ---')\n",
        "  optimizer.zero_grad()\n",
        "  for i in range(len(X)):\n",
        "    pred = model(X[i])\n",
        "    loss = loss_fn(pred, Y[i].unsqueeze(0))\n",
        "    loss.backward()\n",
        "  optimizer.step()"
      ]
    },
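    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an optional sanity check, the following sketch reports how many training samples the model classifies correctly. It assumes the `model`, `X`, and `Y` variables from the previous cells and a decision threshold of `0.5`, the same threshold used later when formatting the output. Because the dataset has only ten fabricated rows, treat the number as a smoke test rather than a measure of model quality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Illustrative check only: evaluate on the same rows used for training.\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "  # model(X) has shape (10, 1); squeeze the last dimension to match Y.\n",
        "  probabilities = model(X).squeeze(1)\n",
        "predicted_labels = (probabilities > 0.5).float()\n",
        "accuracy = (predicted_labels == Y).float().mean().item()\n",
        "print(f'Training accuracy: {accuracy:.2f}')"
      ]
    },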
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m7MD6RwGMdyU"
      },
      "source": [
        "Save the model to the `STATE_DICT_PATH` variable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Q9WIjw53MgcR"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "STATE_DICT_PATH = './model.pth'\n",
        "torch.save(model.state_dict(), STATE_DICT_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CJVYA0N0MnZS"
      },
      "source": [
        "### Publish messages to Pub/Sub\n",
        "Create the Pub/Sub topic and subscription to use for data streaming."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0uwZz_ijyzL8"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "# Replace <TOPIC_NAME> with the name of your Pub/Sub topic.\n",
        "TOPIC = \"<TOPIC_NAME>\" # @param {type:'string'}\n",
        "\n",
        "# Replace <SUBSCRIPTION_PATH> with the subscription for your topic.\n",
        "SUBSCRIPTION = \"<SUBSCRIPTION_PATH>\" # @param {type:'string'}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hIgsCWIozdDu"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "from google.api_core.exceptions import AlreadyExists\n",
        "\n",
        "publisher = pubsub_v1.PublisherClient()\n",
        "topic_path = publisher.topic_path(PROJECT_ID, TOPIC)\n",
        "try:\n",
        "  topic = publisher.create_topic(request={\"name\": topic_path})\n",
        "  print(f\"Created topic: {topic.name}\")\n",
        "except AlreadyExists:\n",
        "  print(f\"Topic {topic_path} already exists.\")\n",
        "\n",
        "subscriber = pubsub_v1.SubscriberClient()\n",
        "subscription_path = subscriber.subscription_path(PROJECT_ID, SUBSCRIPTION)\n",
        "try:\n",
        "    subscription = subscriber.create_subscription(\n",
        "        request={\"name\": subscription_path, \"topic\": topic_path}\n",
        "    )\n",
        "    print(f\"Created subscription: {subscription.name}\")\n",
        "except AlreadyExists:\n",
        "  print(f\"Subscription {subscription_path} already exists.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VqUaFm_yywjU"
      },
      "source": [
        "\n",
        "Use the Pub/Sub Python client to publish messages."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fOq1uNXvMku-"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "messages = [\n",
        "    {'customer_id': i}\n",
        "    for i in range(1, 6)\n",
        "]\n",
        "\n",
        "for message in messages:\n",
        "  data = json.dumps(message).encode('utf-8')\n",
        "  publish_future = publisher.publish(topic_path, data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "giXOGruKM8ZL"
      },
      "source": [
        "## Use the BigQuery enrichment handler\n",
        "\n",
        "The [`BigQueryEnrichmentHandler`](https://beam.apache.org/releases/pydoc/current/apache_beam.transforms.enrichment_handlers.bigquery.html#apache_beam.transforms.enrichment_handlers.bigquery.BigQueryEnrichmentHandler) is a built-in handler included in the Apache Beam SDK versions 2.57.0 and later.\n",
        "\n",
        "Configure the `BigQueryEnrichmentHandler` handler with the following parameters.\n",
        "\n",
        "### Required parameters\n",
        "\n",
        "The following parameters are required.\n",
        "\n",
        "* `project` (str): The Google Cloud project ID for the BigQuery table\n",
        "\n",
        "You must also provide one of the following combinations:\n",
        "* `table_name`, `row_restriction_template`, and `fields`\n",
        "* `table_name`, `row_restriction_template`, and `condition_value_fn`\n",
        "* `query_fn`\n",
        "\n",
        "### Optional parameters\n",
        "\n",
        "The following parameters are optional.\n",
        "\n",
        "* `table_name` (str): The fully qualified BigQuery table name in the format `project.dataset.table`\n",
        "* `row_restriction_template` (str): A template string for the `WHERE` clause in the BigQuery query with placeholders (`{}`) to dynamically filter rows based on input data\n",
        "* `fields` (Optional[List[str]]): A list of field names present in the input `beam.Row`. These fields names are used to construct the `WHERE` clause if `condition_value_fn` is not provided.\n",
        "* `column_names` (Optional[List[str]]): The names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected.\n",
        "* `condition_value_fn` (Optional[Callable[[beam.Row], List[Any]]]): A function that takes a `beam.Row` and returns a list of values to populate in the placeholder `{}` of the `WHERE` clause in the query\n",
        "* `query_fn` (Optional[Callable[[beam.Row], str]]): A function that takes a `beam.Row` and returns a complete BigQuery SQL query string\n",
        "* `min_batch_size` (int): The minimum number of rows to batch together when querying BigQuery. Defaults to `1` if `query_fn` is not specified.\n",
        "* `max_batch_size` (int): The maximum number of rows to batch together. Defaults to `10,000` if `query_fn` is not specified.\n",
        "\n",
        "### Parameter requirements\n",
        "\n",
        "When you use parameters, consider the following requirements.\n",
        "\n",
        "* You can't define the `min_batch_size` and `max_batch_size` parameters if you provide the `query_fn` parameter.\n",
        "* You must provide either the `fields` parameter or the `condition_value_fn` parameter for query construction if you don't provide the `query_fn` parameter.\n",
        "* You must grant the appropriate permissions to access BigQuery.\n",
        "\n",
        "### Create handlers\n",
        "\n",
        "In this example, you create two handlers:\n",
        "\n",
        "* One for customer data that specifies `table_name` and `row_restriction_template`\n",
        "* One for usage data that uses a custom aggregation query by using the `query_fn` function\n",
        "\n",
        "These handlers are used in the Enrichment transforms in this pipeline to fetch and join data from BigQuery with the streaming data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C8XLmBDeMyrB"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "user_data_handler = BigQueryEnrichmentHandler(\n",
        "    project=PROJECT_ID,\n",
        "    table_name=f\"`{CUSTOMERS_TABLE_ID}`\",\n",
        "    row_restriction_template='customer_id = {}',\n",
        "    fields=['customer_id']\n",
        ")\n",
        "\n",
        "# Define the SQL query for usage data aggregation.\n",
        "usage_data_query_template = f\"\"\"\n",
        "WITH monthly_aggregates AS (\n",
        "  SELECT\n",
        "    customer_id,\n",
        "    DATE_TRUNC(date, MONTH) as month,\n",
        "    SUM(calls_made) as total_calls,\n",
        "    SUM(data_usage_gb) as total_data_usage_gb\n",
        "  FROM\n",
        "    `{USAGE_TABLE_ID}`\n",
        "  WHERE\n",
        "    customer_id = @customer_id\n",
        "  GROUP BY\n",
        "    customer_id, month\n",
        ")\n",
        "SELECT\n",
        "  customer_id,\n",
        "  AVG(total_calls) as avg_monthly_calls,\n",
        "  AVG(total_data_usage_gb) as avg_monthly_data_usage_gb\n",
        "FROM\n",
        "  monthly_aggregates\n",
        "GROUP BY\n",
        "  customer_id\n",
        "\"\"\"\n",
        "\n",
        "def usage_data_query_fn(row: beam.Row) -> str:\n",
        "    return usage_data_query_template.replace('@customer_id', str(row.customer_id))\n",
        "\n",
        "usage_data_handler = BigQueryEnrichmentHandler(\n",
        "    project=PROJECT_ID,\n",
        "    query_fn=usage_data_query_fn\n",
        ")"
      ]
    },
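    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `query_fn` returns, you can call it locally on a sample row before running the pipeline. This sketch assumes the `usage_data_query_fn` function defined above; the sample row with `customer_id=1` is made up for illustration. No request is sent to BigQuery."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import apache_beam as beam\n",
        "\n",
        "# Hypothetical input row, used only to preview the generated SQL.\n",
        "sample_row = beam.Row(customer_id=1)\n",
        "print(usage_data_query_fn(sample_row))"
      ]
    },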
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3oPYypvmPiyg"
      },
      "source": [
        "In this example:\n",
        "1. The `user_data_handler` handler uses the `table_name`, `row_restriction_template`, and `fields` parameter combination to fetch customer data.\n",
        "2. The `usage_data_handler` handler uses the `query_fn` parameter to execute a more complex query that aggregates usage data."
      ]
    },
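    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The remaining parameter combination, `table_name`, `row_restriction_template`, and `condition_value_fn`, isn't used in this notebook. The following is a minimal sketch of what it might look like, assuming the `PROJECT_ID` and `CUSTOMERS_TABLE_ID` variables defined earlier. The names `customer_id_condition` and `customer_plan_handler` are hypothetical, and the handler isn't added to the pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import apache_beam as beam\n",
        "from apache_beam.transforms.enrichment_handlers.bigquery import BigQueryEnrichmentHandler\n",
        "\n",
        "def customer_id_condition(row: beam.Row):\n",
        "  # Return one value for each `{}` placeholder in the row restriction template.\n",
        "  return [row.customer_id]\n",
        "\n",
        "# Hypothetical handler: looks up only the listed columns for each customer.\n",
        "customer_plan_handler = BigQueryEnrichmentHandler(\n",
        "    project=PROJECT_ID,\n",
        "    table_name=f\"`{CUSTOMERS_TABLE_ID}`\",\n",
        "    row_restriction_template='customer_id = {}',\n",
        "    condition_value_fn=customer_id_condition,\n",
        "    column_names=['customer_id', 'plan', 'contract_length']\n",
        ")"
      ]
    },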
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ksON9uOBQbZm"
      },
      "source": [
        "## Use the `PytorchModelHandlerTensor` interface to run inference\n",
        "\n",
        "Define functions to convert enriched data to the tensor format for the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XgPontIVP0Cv"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "def convert_row_to_tensor(customer_data):\n",
        "    import pandas as pd\n",
        "    customer_df = pd.DataFrame([customer_data[1].as_dict()])\n",
        "    customer_df['plan'] = plan_encoder.transform(customer_df['plan'])\n",
        "    return (customer_data[0], torch.tensor(customer_df[features].values, dtype=torch.float))\n",
        "\n",
        "keyed_model_handler = KeyedModelHandler(PytorchModelHandlerTensor(\n",
        "    state_dict_path=STATE_DICT_PATH,\n",
        "    model_class=build_model,\n",
        "    model_params={'n_inputs':5, 'n_outputs':1}\n",
        ")).with_preprocess_fn(convert_row_to_tensor)"
      ]
    },
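    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "You can try the preprocessing function on a single element without starting the pipeline. The sketch below assumes the `convert_row_to_tensor` function from the previous cell; the sample row is fabricated and mirrors the fields that the two enrichment steps are expected to add."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import apache_beam as beam\n",
        "\n",
        "# Fabricated enriched row, shaped like the pipeline's keyed elements.\n",
        "sample_element = (1, beam.Row(\n",
        "    customer_id=1,\n",
        "    age=35,\n",
        "    plan='Gold',\n",
        "    contract_length=12,\n",
        "    avg_monthly_calls=57.5,\n",
        "    avg_monthly_data_usage_gb=11.0))\n",
        "\n",
        "key, tensor = convert_row_to_tensor(sample_element)\n",
        "print(key, tensor.shape)"
      ]
    },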
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O9e7ddgGQxh2"
      },
      "source": [
        "Define a `DoFn` to format the output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NMj0V5VyQukk"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "class PostProcessor(beam.DoFn):\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    print('Customer %d churn risk: %s' % (element[0], \"High\" if element[1].inference[0].item() > 0.5 else \"Low\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-N3a1s2FQ66z"
      },
      "source": [
        "## Run the pipeline\n",
        "\n",
        "Configure the pipeline to run in streaming mode."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rgJeV-jWQ4wo"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "options = pipeline_options.PipelineOptions()\n",
        "options.view_as(pipeline_options.StandardOptions).streaming = True # Streaming mode is set True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NRljYVR5RCMi"
      },
      "source": [
        "Pub/Sub sends the data in bytes. Convert the data to `beam.Row` objects by using a `DoFn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bb-e3yjtQ2iU"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "class DecodeBytes(beam.DoFn):\n",
        "  \"\"\"\n",
        "  The DecodeBytes `DoFn` converts the data read from Pub/Sub to `beam.Row`.\n",
        "  First, decode the encoded string. Convert the output to\n",
        "  a `dict` with `json.loads()`, which is used to create a `beam.Row`.\n",
        "  \"\"\"\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    element_dict = json.loads(element.decode('utf-8'))\n",
        "    yield beam.Row(**element_dict)"
      ]
    },
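    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check the decoding step locally, call `process` directly on a sample payload. This sketch assumes the `DecodeBytes` class from the previous cell; the payload is a made-up message in the same JSON shape that the publisher cell sends."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "# Made-up payload that mimics one Pub/Sub message body.\n",
        "sample_payload = json.dumps({'customer_id': 1}).encode('utf-8')\n",
        "\n",
        "# `process` is a generator, so collect its output into a list.\n",
        "decoded_rows = list(DecodeBytes().process(sample_payload))\n",
        "print(decoded_rows)"
      ]
    },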
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q1HV8wH-RIbj"
      },
      "source": [
        "Use the following code to run the pipeline.\n",
        "\n",
        "**Note:** Because this pipeline is a streaming pipeline, you need to manually stop the cell. If you don't stop the cell, the pipeline continues to run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y6HBH8yoRFp2"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "with beam.Pipeline(options=options) as p:\n",
        "  _ = (p\n",
        "       | \"Read from Pub/Sub\" >> beam.io.ReadFromPubSub(subscription=f\"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION}\")\n",
        "       | \"ConvertToRow\" >> beam.ParDo(DecodeBytes())\n",
        "       | \"Enrich with customer data\" >> Enrichment(user_data_handler)\n",
        "       | \"Enrich with usage data\" >> Enrichment(usage_data_handler)\n",
        "       | \"Key data\" >> beam.Map(lambda x: (x.customer_id, x))\n",
        "       | \"RunInference\" >> RunInference(keyed_model_handler)\n",
        "       | \"Format Output\" >> beam.ParDo(PostProcessor())\n",
        "       )"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "include_colab_link": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
